/*
Copyright (C) 2021 by clash authors

This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, version 3.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with this program.  If not, see <https://www.gnu.org/licenses/>.
*/

package protocol

import (
	"bytes"
	"encoding/binary"
	"hash/adler32"
	"hash/crc32"
	"math/rand"
	"net"

	"libcore/clash/common/pool"
	"libcore/clash/transport/ssr/tools"
)

// init registers the "auth_sha1_v4" protocol with its constructor so it can
// be looked up by name.
// NOTE(review): the trailing 7 is forwarded to register unchanged; it is
// presumably the per-frame overhead in bytes — confirm against register.
func init() {
	register("auth_sha1_v4", newAuthSHA1V4, 7)
}

// authSHA1V4 carries the state of one auth_sha1_v4 protocol instance.
//
// It embeds the shared *Base and *authData. iv is the value handed to
// StreamConn and is combined with Key to build the HMAC-SHA1 key of the
// first (auth) packet. hasSentHeader records that Encode has already emitted
// that auth packet. rawTrans, once set by Decode after a length or Adler-32
// failure, makes every later Decode call copy its input through untouched.
type authSHA1V4 struct {
	*Base
	*authData
	iv            []byte
	hasSentHeader bool
	rawTrans      bool
}

// newAuthSHA1V4 builds the protocol instance that is registered for
// "auth_sha1_v4": it wraps the supplied Base and starts with a zero-valued
// authData.
func newAuthSHA1V4(b *Base) Protocol {
	proto := new(authSHA1V4)
	proto.Base = b
	proto.authData = new(authData)
	return proto
}

// StreamConn wraps c in a Conn driven by a fresh authSHA1V4 that shares this
// instance's Base, takes its authData from a.next(), and remembers iv for
// the auth packet's HMAC key.
func (a *authSHA1V4) StreamConn(c net.Conn, iv []byte) net.Conn {
	session := &authSHA1V4{
		Base:     a.Base,
		authData: a.next(),
		iv:       iv,
	}
	return &Conn{Conn: c, Protocol: session}
}

// PacketConn returns c unchanged: this protocol applies no wrapping to
// packet connections.
func (a *authSHA1V4) PacketConn(c net.PacketConn) net.PacketConn {
	return c
}

// Decode strips the auth_sha1_v4 framing from src and appends the recovered
// payload to dst.
//
// Each frame is laid out as:
//
//	2 bytes  big-endian total frame length
//	2 bytes  little-endian low 16 bits of CRC-32 over the length field
//	1 or 3   padding prefix (one byte < 255 giving the payload offset minus 4,
//	         or 255 followed by a big-endian uint16 giving that value)
//	n bytes  random padding
//	m bytes  payload
//	4 bytes  little-endian Adler-32 over everything before it
//
// Complete frames are consumed from src; a trailing partial frame is left in
// src for the next call. Once rawTrans is set, src is copied to dst verbatim.
//
// It returns errAuthSHA1V4CRC32Error when the length checksum mismatches,
// errAuthSHA1V4LengthError when the frame length is outside [7, 8192) or the
// payload offset points past the payload end, and errAuthSHA1V4Adler32Error
// when the frame checksum mismatches. src is reset on every error, and the
// latter two also switch the connection to raw pass-through.
func (a *authSHA1V4) Decode(dst, src *bytes.Buffer) error {
	if a.rawTrans {
		dst.ReadFrom(src)
		return nil
	}
	for src.Len() > 4 {
		frame := src.Bytes()
		if uint16(crc32.ChecksumIEEE(frame[:2])&0xffff) != binary.LittleEndian.Uint16(frame[2:4]) {
			src.Reset()
			return errAuthSHA1V4CRC32Error
		}

		length := int(binary.BigEndian.Uint16(frame[:2]))
		if length >= 8192 || length < 7 {
			a.rawTrans = true
			src.Reset()
			return errAuthSHA1V4LengthError
		}
		if length > len(frame) {
			// Frame not fully received yet; keep it buffered.
			break
		}

		end := length - 4
		if adler32.Checksum(frame[:end]) != binary.LittleEndian.Uint32(frame[end:length]) {
			a.rawTrans = true
			src.Reset()
			return errAuthSHA1V4Adler32Error
		}

		pos := int(frame[4])
		if pos < 255 {
			pos += 4
		} else {
			pos = int(binary.BigEndian.Uint16(frame[5:7])) + 4
		}
		// The offset comes from the peer: a checksum-valid frame can still
		// claim padding that overruns the payload, which would make the
		// slice expression below panic. Treat it as a malformed length.
		if pos > end {
			a.rawTrans = true
			src.Reset()
			return errAuthSHA1V4LengthError
		}
		dst.Write(frame[pos:end])
		src.Next(length)
	}
	return nil
}

// Encode frames b into buf. On the first call the leading getDataLength(b)
// bytes are emitted inside the auth packet; whatever remains (and all data on
// later calls) is split into data frames carrying at most 8100 payload bytes
// each. It always returns nil.
func (a *authSHA1V4) Encode(buf *bytes.Buffer, b []byte) error {
	const maxChunk = 8100

	rest := b
	if !a.hasSentHeader {
		headLen := getDataLength(rest)
		a.packAuthData(buf, rest[:headLen])
		rest = rest[headLen:]
		a.hasSentHeader = true
	}

	for ; len(rest) > maxChunk; rest = rest[maxChunk:] {
		a.packData(buf, rest[:maxChunk])
	}
	if len(rest) != 0 {
		a.packData(buf, rest)
	}
	return nil
}

// DecodePacket returns b unchanged with a nil error: packets are not
// transformed by this protocol.
func (a *authSHA1V4) DecodePacket(b []byte) ([]byte, error) { return b, nil }

// EncodePacket appends b to buf as-is; no framing is applied to packets.
// The returned error is the one reported by bytes.Buffer.Write, which is
// always nil.
func (a *authSHA1V4) EncodePacket(buf *bytes.Buffer, b []byte) error {
	_, err := buf.Write(b)
	return err
}

// packData appends one data frame wrapping data to poolBuf.
//
// Frame layout:
//
//	2 bytes  big-endian total frame length
//	2 bytes  little-endian low 16 bits of CRC-32 over the length field
//	1 or 3   padding prefix (1 byte when the padding is shorter than 128)
//	n bytes  random padding
//	m bytes  data
//	4 bytes  little-endian Adler-32 over the frame so far
func (a *authSHA1V4) packData(poolBuf *bytes.Buffer, data []byte) {
	padLen := a.getRandDataLength(len(data))

	// Short padding uses the one-byte prefix form, long padding the
	// three-byte form.
	prefixLen := 3
	if padLen < 128 {
		prefixLen = 1
	}
	frameLen := 2 + 2 + prefixLen + padLen + len(data) + 4

	var head [4]byte
	binary.BigEndian.PutUint16(head[:2], uint16(frameLen))
	binary.LittleEndian.PutUint16(head[2:], uint16(crc32.ChecksumIEEE(head[:2])&0xffff))
	poolBuf.Write(head[:])

	a.packRandData(poolBuf, padLen)
	poolBuf.Write(data)

	// Everything of this frame except the 4-byte trailer is now in the
	// buffer, so the frame starts frameLen-4 bytes before the end.
	var trailer [4]byte
	binary.LittleEndian.PutUint32(trailer[:], adler32.Checksum(poolBuf.Bytes()[poolBuf.Len()-frameLen+4:]))
	poolBuf.Write(trailer[:])
}

// packAuthData appends the initial auth packet wrapping data to poolBuf.
//
// Packet layout:
//
//	2 bytes   big-endian total packet length
//	4 bytes   little-endian CRC-32 over length || "auth_sha1_v4" || Key
//	1 or 3    padding prefix (1 byte when the padding is shorter than 128)
//	n bytes   random padding
//	12 bytes  auth data written by putAuthData
//	m bytes   data
//	10 bytes  leading bytes of HMAC-SHA1, keyed with iv || Key, computed over
//	          the packet so far
func (a *authSHA1V4) packAuthData(poolBuf *bytes.Buffer, data []byte) {
	padLen := a.getRandDataLength(12 + len(data))

	prefixLen := 3
	if padLen < 128 {
		prefixLen = 1
	}
	packetLen := 2 + 4 + prefixLen + padLen + 12 + len(data) + 10

	// CRC input: the length field followed by the protocol salt and the key.
	salt := []byte("auth_sha1_v4")
	crcInput := pool.Get(len(salt) + len(a.Key) + 2)
	defer pool.Put(crcInput)
	binary.BigEndian.PutUint16(crcInput, uint16(packetLen))
	copy(crcInput[2:], salt)
	copy(crcInput[2+len(salt):], a.Key)

	// HMAC key: iv followed by the key.
	hmacKey := pool.Get(len(a.iv) + len(a.Key))
	defer pool.Put(hmacKey)
	copy(hmacKey, a.iv)
	copy(hmacKey[len(a.iv):], a.Key)

	poolBuf.Write(crcInput[:2])
	var sum [4]byte
	binary.LittleEndian.PutUint32(sum[:], crc32.ChecksumIEEE(crcInput))
	poolBuf.Write(sum[:])

	a.packRandData(poolBuf, padLen)
	a.putAuthData(poolBuf)
	poolBuf.Write(data)

	// All but the 10-byte tag is written, so the packet starts packetLen-10
	// bytes before the current end of the buffer.
	mac := tools.HmacSHA1(hmacKey, poolBuf.Bytes()[poolBuf.Len()-packetLen+10:])
	poolBuf.Write(mac[:10])
}

// packRandData appends the padding prefix followed by size random bytes to
// poolBuf. For size < 128 the prefix is the single byte size+1; otherwise it
// is 255 followed by size+3 as a big-endian uint16.
func (a *authSHA1V4) packRandData(poolBuf *bytes.Buffer, size int) {
	if size >= 128 {
		var prefix [3]byte
		prefix[0] = 255
		binary.BigEndian.PutUint16(prefix[1:], uint16(size+3))
		poolBuf.Write(prefix[:])
	} else {
		poolBuf.WriteByte(byte(size + 1))
	}
	tools.AppendRandBytes(poolBuf, size)
}

// getRandDataLength picks how many random padding bytes accompany a payload
// of the given size: none above 1200 bytes, a random count in [0, 256) above
// 400 bytes, and a random count in [0, 512) otherwise.
func (a *authSHA1V4) getRandDataLength(size int) int {
	switch {
	case size > 1200:
		return 0
	case size > 400:
		return rand.Intn(256)
	default:
		return rand.Intn(512)
	}
}
